# -*- coding: utf-8 -*-
"""
描述|Description:
      辅助函数定义

作者|Author(s):
      lc

版权|Copyright:
      lc 所有权利保留。lc ALL RIGHTS RESERVED.

文档|Document:

    提供连接 SQL Server 数据库、以及生成数据库镜像(mirroring)配置 SQL 脚本的相关工具函数

更新|update:
      < 2022/05/05 李畅 (Li Chang)>

"""
from typing import List, Tuple
from docutils import ApplicationError
from jinja2 import Environment, PackageLoader, select_autoescape
import pymssql
import os
import json
import datetime


class ConnectionHelper:
    """
    Holds SQL Server connection parameters and opens a pymssql connection
    when used as a context manager.

    Usage::

        with ConnectionHelper(...) as helper:
            cursor = helper.connection.cursor()

    Outside a ``with`` block ``connection`` is None.
    """

    def __init__(self, server: str, user: str, password: str, charset: str, database: str, port: str) -> None:
        # Parameters are stored as given and passed to pymssql.connect on __enter__.
        self._server = server
        self._user = user
        self._password = password
        self._charset = charset
        self._database = database
        self._connection = None
        self._port = port

    def __enter__(self) -> 'ConnectionHelper':
        """Open the connection and return this helper so ``with ... as h`` works."""
        self._connection = pymssql.connect(
            host=self._server, database=self._database, user=self._user, password=self._password, charset=self._charset, port=self._port)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the connection if one is open; exceptions are not suppressed."""
        if self._connection is not None:
            try:
                self._connection.close()
            finally:
                # Drop the reference so `connection` never hands out a closed handle.
                self._connection = None

    def __str__(self) -> str:
        return f'{self._server}:{self._database}'

    @property
    def connection(self):
        """
        The open pymssql connection, or None when not inside a ``with`` block.
        """
        return self._connection


def manual_connect() -> Tuple[ConnectionHelper, ConnectionHelper]:
    """
    Interactively collect connection settings for both servers.

    The master server is prompted for first, then the mirror server.

    Returns:
        (master, mirror) pair of ConnectionHelper objects (not yet connected).
    """
    return manual_connect_master(), manual_connect_mirror()


def manual_connect_master() -> ConnectionHelper:
    """
    Interactively prompt for the master server's address, user and password.

    Each prompt is repeated until a non-empty answer is given. Address and user
    are stripped of surrounding whitespace; the password is taken verbatim.
    The helper targets the ``master`` database on port 1433 with charset utf8.

    Returns:
        ConnectionHelper: not yet connected; enter it with ``with`` to connect.
    """
    # input() returns '' (never None) for an empty line, so the retry loops
    # test falsiness rather than None.
    # NOTE(review): the prompts advertise an "a[bort]" option, but no abort
    # handling exists; typing "a" is accepted as an ordinary answer.
    netaddress = ''
    while not netaddress:
        netaddress = input('主服务器地址(a[bort/中止]):').strip()

    user = ''
    while not user:
        user = input('主服务器用户(a[bort/中止]):').strip()

    password = ''
    while not password:
        password = input('主服务器密码(a[bort/中止]):')

    return ConnectionHelper(
        server=netaddress, user=user, password=password, database='master', charset='utf8', port='1433')


def manual_connect_mirror() -> ConnectionHelper:
    """
    Interactively prompt for the mirror server's address, user and password.

    Each prompt is repeated until a non-empty answer is given. Address and user
    are stripped of surrounding whitespace; the password is taken verbatim.
    The helper targets the ``master`` database on port 1433 with charset utf8.

    Returns:
        ConnectionHelper: not yet connected; enter it with ``with`` to connect.
    """
    # input() returns '' (never None) for an empty line, so the retry loops
    # test falsiness rather than None.
    # NOTE(review): the prompts advertise an "a[bort]" option, but no abort
    # handling exists; typing "a" is accepted as an ordinary answer.
    netaddress = ''
    while not netaddress:
        netaddress = input('镜像服务器地址(a[bort/中止]):').strip()

    user = ''
    while not user:
        user = input('镜像服务器用户(a[bort/中止]):').strip()

    password = ''
    while not password:
        password = input('镜像服务器密码(a[bort/中止]):')

    return ConnectionHelper(
        server=netaddress, user=user, password=password, database='master', charset='utf8', port='1433')


def connect_from_json(json_file: str) -> Tuple[ConnectionHelper, ConnectionHelper]:
    """
    Build master/mirror connection helpers from a JSON config file.

    The file must hold ``master`` and ``mirror`` objects, each with the keys
    ``address``, ``user``, ``password``, ``charset``, ``database`` and
    ``port`` (the layout written by ``save_connect``).

    Args:
        json_file: path of the config file.

    Returns:
        (master, mirror) pair of ConnectionHelper objects (not yet connected).

    Raises:
        OSError: the file cannot be opened.
        Any error while reading/parsing the file or looking up keys (e.g.
        json.JSONDecodeError, KeyError) is re-raised unchanged after a
        message is printed.
    """
    def _helper_from(section) -> 'ConnectionHelper':
        # Map one config section onto ConnectionHelper's keyword arguments.
        return ConnectionHelper(server=section['address'], user=section['user'], password=section['password'],
                                charset=section['charset'], database=section['database'], port=section['port'])

    with open(json_file, 'r') as config_file:
        try:
            config_data = json.load(config_file)
            return _helper_from(config_data['master']), _helper_from(config_data['mirror'])
        except Exception:
            print(f'解析配置文件"{json_file}"失败')
            # Bare raise re-raises the active exception with its traceback intact.
            raise


def save_connect(connect_master: ConnectionHelper, connect_mirror: ConnectionHelper, json_file: str):
    """
    Write both servers' connection settings to ``json_file``, overwriting it.

    The layout matches what ``connect_from_json`` reads back. Passwords are
    stored in plain text.
    """
    def as_section(helper):
        # Key order is kept stable so the file layout does not change.
        return {
            'address': helper._server,
            'user': helper._user,
            'password': helper._password,
            'database': helper._database,
            'charset': helper._charset,
            'port': helper._port,
        }

    payload = {
        'master': as_section(connect_master),
        'mirror': as_section(connect_mirror),
    }
    with open(json_file, 'w') as out:
        json.dump(payload, fp=out)


def master_outbound_setup(connect_master: ConnectionHelper):
    """
    Generate ``step1_master_outbound.sql`` for the master server.

    The script is rendered from the outbound template and written to the
    current working directory. The certificate file it refers to
    (``master_cert.cer`` under the certificate directory) must not already
    exist; if it does, the user is asked whether to delete it.

    Args:
        connect_master: helper with an open connection to the master server,
            used to check whether a database master key already exists.

    Returns:
        str: path of the certificate file referenced by the script.

    Raises:
        ApplicationError: the user declined to delete an existing certificate file.
    """
    master_key_password = None
    cert_dir_path = 'c:\\cer\\'
    cert_name = 'master_cert'
    cert_file_path = os.path.join(cert_dir_path, f'{cert_name}.cer')

    # Creates missing parents too; still raises if the path exists as a file.
    os.makedirs(cert_dir_path, exist_ok=True)

    if os.path.isfile(cert_file_path):
        # Accept only y/n (either case); any other answer re-prompts, so the
        # "no" branch below is actually reachable.
        answer = None
        while answer not in ('y', 'Y', 'n', 'N'):
            answer = input(f'{cert_file_path}文件已经存在，必须删除才能继续,是否删除?(y/n)').strip()
        if answer in ('y', 'Y'):
            os.remove(cert_file_path)
        else:
            raise ApplicationError('无法产生主服务器的出站设置sql')

    # A master-key password is only collected when no master key exists yet;
    # otherwise None is passed to the template.
    if exists_master_key(connect_master):
        print('主服务器主加密密钥已存在')
    else:
        master_key_password = input('输入主加密密钥的口令:')
    endpoint_name = 'Endpoint_Mirroring'
    port = 5022

    # Render the master server's outbound SQL and write it to disk.
    master_outbound_sql = get_outbound_sql('主', '镜像', master_key_password=master_key_password,
                                           cert_name=cert_name, endpoint_name=endpoint_name, port=port, cert_file_path=cert_file_path)
    with open('step1_master_outbound.sql', 'w') as tmp_file:
        tmp_file.write(master_outbound_sql)

    print(
        f'主服务器出站sql文件已产生,请以sa用户登录主服务器数据库实例，执行“step1_master_outbound.sql”,然后将产生的{cert_file_path}拷贝到镜像服务器的{cert_dir_path}下，并在主服务器的防火墙允许{port}端口，再执行下面的步骤2')

    return cert_file_path


def mirror_outbound_setup(connect_mirror: ConnectionHelper):
    """
    Generate ``step2_mirror_outbound.sql`` for the mirror server.

    The script is rendered from the outbound template and written to the
    current working directory. The certificate file it refers to
    (``mirror_cert.cer`` under the certificate directory) must not already
    exist; if it does, the user is asked whether to delete it.

    Args:
        connect_mirror: helper with an open connection to the mirror server,
            used to check whether a database master key already exists.

    Returns:
        str: path of the certificate file referenced by the script.

    Raises:
        ApplicationError: the user declined to delete an existing certificate
            file. (The previous code raised an undefined name here.)
    """
    master_key_password = None
    cert_dir_path = 'c:\\cer\\'
    cert_name = 'mirror_cert'
    cert_file_path = os.path.join(cert_dir_path, f'{cert_name}.cer')

    # Creates missing parents too; still raises if the path exists as a file.
    os.makedirs(cert_dir_path, exist_ok=True)

    if os.path.isfile(cert_file_path):
        # Accept only y/n (either case); any other answer re-prompts, so the
        # "no" branch below is actually reachable.
        answer = None
        while answer not in ('y', 'Y', 'n', 'N'):
            answer = input(f'{cert_file_path}文件已经存在，必须删除才能继续,是否删除?(y/n)').strip()
        if answer in ('y', 'Y'):
            os.remove(cert_file_path)
        else:
            raise ApplicationError('无法产生镜像服务器的出站设置sql')

    # A master-key password is only collected when no master key exists yet;
    # otherwise None is passed to the template.
    if exists_master_key(connect_mirror):
        print('镜像服务器主加密密钥已存在')
    else:
        master_key_password = input('输入主加密密钥的口令:')
    endpoint_name = 'Endpoint_Mirroring'
    # Kept distinct from the master endpoint port (5022).
    port = 5023

    # Render the mirror server's outbound SQL and write it to disk.
    mirror_outbound_sql = get_outbound_sql('镜像', '主', master_key_password=master_key_password,
                                           cert_name=cert_name, endpoint_name=endpoint_name, port=port, cert_file_path=cert_file_path)
    with open('step2_mirror_outbound.sql', 'w') as tmp_file:
        tmp_file.write(mirror_outbound_sql)

    print(
        f'镜像服务器出站sql文件已产生,请以sa用户登录镜像服务器数据库实例，执行“step2_mirror_outbound.sql”,然后将产生的{cert_file_path}拷贝到主服务器的{cert_dir_path}下，并在镜像服务器的防火墙允许{port}端口，再执行下面的步骤3')

    return cert_file_path


def master_inbound_setup(connect_master: ConnectionHelper, master_cert_file: str):
    """
    Generate ``step3_master_inbound.sql`` for the master server.

    Renders the inbound template with the mirror side's certificate
    (``mirror_cert``), user (``mirror_user``) and login (``mirror_login``).
    The user is prompted for the login's password. The script is written to
    the current working directory.

    Args:
        connect_master: currently unused; kept for interface compatibility.
        master_cert_file: currently unused; kept for interface compatibility.
    """
    target_name = '主'
    counter_name = '镜像'
    cert_name = 'mirror_cert'
    user = 'mirror_user'
    login = 'mirror_login'

    login_password = input(f'请输入镜像用户登录({login})的密码:')

    # Escaped backslash: 'c:\cer' relied on "\c" being an invalid escape
    # (SyntaxWarning on Python 3.12+); the string value is unchanged.
    cert_dir_path = 'c:\\cer'
    cert_file_path = f'{cert_name}.cer'

    endpoint_name = 'Endpoint_Mirroring'

    # Render the master server's inbound SQL and write it to disk.
    master_inbound_sql = get_inbound_sql(target_name=target_name,
                                         counter_name=counter_name,
                                         cert_name=cert_name,
                                         user=user,
                                         login=login,
                                         login_password=login_password,
                                         cert_dir_path=cert_dir_path,
                                         cert_file_path=cert_file_path,
                                         endpoint_name=endpoint_name)
    with open('step3_master_inbound.sql', 'w') as tmp_file:
        tmp_file.write(master_inbound_sql)

    print('主服务器入站sql文件已产生,请以sa用户登录主服务器数据库实例，执行"step3_master_inbound.sql"，再执行下面的步骤')


def mirror_inbound_setup(connect_mirror: ConnectionHelper, mirror_cert_file: str):
    """
    Generate ``step4_mirror_inbound.sql`` for the mirror server.

    Renders the inbound template with the master side's certificate
    (``master_cert``), user (``master_user``) and login (``master_login``).
    The user is prompted for the login's password. The script is written to
    the current working directory.

    Args:
        connect_mirror: currently unused; kept for interface compatibility.
        mirror_cert_file: currently unused; kept for interface compatibility.
    """
    target_name = '镜像'
    counter_name = '主'
    cert_name = 'master_cert'
    user = 'master_user'
    login = 'master_login'

    # The login created here belongs to the master side, so the prompt names
    # it the "master user login" (it was mislabelled as the mirror's).
    login_password = input(f'请输入主用户登录({login})的密码:')

    # Escaped backslash: 'c:\cer' relied on "\c" being an invalid escape
    # (SyntaxWarning on Python 3.12+); the string value is unchanged.
    cert_dir_path = 'c:\\cer'
    cert_file_path = f'{cert_name}.cer'

    endpoint_name = 'Endpoint_Mirroring'

    # Render the mirror server's inbound SQL and write it to disk.
    mirror_inbound_sql = get_inbound_sql(target_name=target_name,
                                         counter_name=counter_name,
                                         cert_name=cert_name,
                                         user=user,
                                         login=login,
                                         login_password=login_password,
                                         cert_dir_path=cert_dir_path,
                                         cert_file_path=cert_file_path,
                                         endpoint_name=endpoint_name)
    with open('step4_mirror_inbound.sql', 'w') as tmp_file:
        tmp_file.write(mirror_inbound_sql)

    print('镜像服务器入站sql文件已产生,请以sa用户登录镜像服务器数据库实例，执行"step4_mirror_inbound.sql"，再执行下面的步骤')


def exists_master_key(connect: ConnectionHelper) -> bool:
    """
    Tell whether the connected database already has a database master key.

    Counts rows in ``sys.symmetric_keys`` named ``##MS_DatabaseMasterKey##``.

    Args:
        connect: helper whose ``connection`` is open.

    Returns:
        True when at least one such key exists, otherwise False.
    """
    cursor = connect.connection.cursor()
    try:
        cursor.execute(
            """SELECT COUNT(name) FROM sys.symmetric_keys WHERE name='##MS_DatabaseMasterKey##'""")
        key_count = cursor.fetchone()[0]
        return key_count != 0
    finally:
        cursor.close()


def create_master_key(connect: ConnectionHelper):
    """
    Prompt for a password and create a database master key protected by it.

    The statement is committed on success; the cursor is always closed.

    Args:
        connect: helper whose ``connection`` is open.
    """
    cur = connect.connection.cursor()
    try:
        password = input('输入主加密密钥的口令:')
        # Double embedded single quotes so the password stays a single T-SQL
        # string literal (a quote in the password previously broke the statement).
        escaped_password = password.replace("'", "''")
        cur.execute(
            f"""CREATE MASTER KEY ENCRYPTION BY PASSWORD = '{escaped_password}'""")
        connect.connection.commit()
    finally:
        cur.close()


def exists_cert(connect: ConnectionHelper, cert_name: str) -> bool:
    """
    Tell whether a certificate named ``cert_name`` exists in the connected database.

    Args:
        connect: helper whose ``connection`` is open.
        cert_name: certificate name to look up in ``sys.certificates``.

    Returns:
        True when at least one matching certificate exists, otherwise False.
    """
    # Double single quotes so any name is embedded as one SQL string literal.
    quoted_name = cert_name.replace("'", "''")
    cur = connect.connection.cursor()
    try:
        cur.execute(
            f"""SELECT COUNT(name) FROM sys.certificates WHERE name='{quoted_name}'""")
        return cur.fetchone()[0] != 0
    finally:
        cur.close()


def create_cert(connect: ConnectionHelper, cert_name: str):
    """
    Create a certificate named ``cert_name`` that expires roughly 95 years from now.

    The statement is committed on success; the cursor is always closed.

    Args:
        connect: helper whose ``connection`` is open.
        cert_name: certificate name; interpolated unquoted, so it must be a
            valid identifier as written.
    """
    cur = connect.connection.cursor()
    try:
        # timedelta has no year unit; 99*50 weeks is roughly 95 years.
        expiry_date = datetime.datetime.now() + datetime.timedelta(weeks=99 * 50)
        # 'YYYYMMDD' is interpreted the same regardless of the session's
        # DATEFORMAT/language, unlike 'm/d/yyyy', which can be read as
        # day/month or rejected.
        cur.execute(
            f"CREATE CERTIFICATE {cert_name} WITH SUBJECT = '服务器证书{cert_name}', "
            f"EXPIRY_DATE = '{expiry_date:%Y%m%d}';")
        connect.connection.commit()
    finally:
        cur.close()


def exists_endpoint(connect: ConnectionHelper, endpoint_name: str) -> bool:
    """
    Tell whether an endpoint named ``endpoint_name`` exists on the instance.

    Previously the ``endpoint_name`` argument was ignored and
    'Endpoint_Mirroring' was always checked.

    Args:
        connect: helper whose ``connection`` is open.
        endpoint_name: endpoint name to look up in ``sys.endpoints``.

    Returns:
        True when at least one matching endpoint exists, otherwise False.
    """
    # Double single quotes so any name is embedded as one SQL string literal.
    quoted_name = endpoint_name.replace("'", "''")
    cur = connect.connection.cursor()
    try:
        cur.execute(
            f"""SELECT COUNT(name) FROM sys.endpoints WHERE name='{quoted_name}'""")
        return cur.fetchone()[0] != 0
    finally:
        cur.close()


def alter_endpoint(connect: ConnectionHelper, endpoint_name: str, listener_port: int, cert_name: str):
    """
    Alter an existing database-mirroring endpoint and commit.

    The endpoint is set to STARTED, listening on ``listener_port`` on all
    IPs, authenticating with certificate ``cert_name``, with AES encryption
    required and ROLE = ALL.

    Args:
        connect: helper whose ``connection`` is open.
        endpoint_name: endpoint to alter; interpolated unquoted into the SQL.
        listener_port: TCP port for the endpoint.
        cert_name: certificate used for endpoint authentication; interpolated unquoted.
    """
    cur = connect.connection.cursor()
    try:
        # The embedded T-SQL comment reads "use the certificate to secure the endpoint".
        cur.execute(f"""ALTER ENDPOINT {endpoint_name}  
         STATE = STARTED  
         AS TCP (  
            LISTENER_PORT={listener_port}
            , LISTENER_IP = ALL  
         )   
         FOR DATABASE_MIRRORING (   
            -- 使用证书加密端点
            AUTHENTICATION = CERTIFICATE {cert_name}  
            , ENCRYPTION = REQUIRED ALGORITHM AES  
            , ROLE = ALL  
         );""")
        connect.connection.commit()
    finally:
        cur.close()


def create_endpoint(connect: ConnectionHelper, endpoint_name: str, listener_port: int, cert_name: str):
    """
    Create a started DATABASE_MIRRORING TCP endpoint secured by a certificate.

    The generated statement is echoed to stdout before it is executed, and is
    committed on success; the cursor is always closed.

    Args:
        connect: helper whose ``connection`` is open.
        endpoint_name: name of the new endpoint; interpolated unquoted.
        listener_port: TCP port to listen on (all IPs).
        cert_name: certificate used for endpoint authentication; interpolated unquoted.
    """
    cursor = connect.connection.cursor()
    try:
        statement = (f"CREATE ENDPOINT {endpoint_name}  STATE = STARTED  "
                     f"AS TCP (LISTENER_PORT={listener_port},LISTENER_IP = ALL)  "
                     f"FOR DATABASE_MIRRORING (AUTHENTICATION = CERTIFICATE {cert_name} , "
                     "ENCRYPTION = REQUIRED ALGORITHM AES , ROLE = ALL );")
        print(statement)
        cursor.execute(statement)
        connect.connection.commit()
    finally:
        cursor.close()


def save_cert(connect: ConnectionHelper, cert_name: str, cert_file_path: str):
    """
    Back up certificate ``cert_name`` to ``cert_file_path`` and commit.

    The path is embedded in a single-quoted SQL literal and the name is
    interpolated unquoted; the cursor is always closed.
    """
    cursor = connect.connection.cursor()
    try:
        statement = f"""BACKUP CERTIFICATE {cert_name} TO FILE = '{cert_file_path}'; """
        cursor.execute(statement)
        connect.connection.commit()
    finally:
        cursor.close()


def exec_sql(connect: ConnectionHelper, sql: str):
    """
    Execute one SQL statement on ``connect`` and commit it.

    The cursor is closed whether or not the statement succeeds.
    """
    db = connect.connection
    cursor = db.cursor()
    try:
        cursor.execute(sql)
        db.commit()
    finally:
        cursor.close()


def drop_procedure(connect: ConnectionHelper, procedure_name: str):
    """
    Drop the stored procedure ``procedure_name`` and commit.

    The name is interpolated unquoted, so it must be a valid identifier as
    written. The cursor is always closed.
    """
    cursor = connect.connection.cursor()
    try:
        statement = f"""DROP PROCEDURE {procedure_name}"""
        cursor.execute(statement)
        connect.connection.commit()
    finally:
        cursor.close()


def exists_procedure(connect: ConnectionHelper, procedure_name: str):
    """
    Tell whether ``procedure_name`` resolves to a stored procedure.

    Looks the name up with ``object_id`` in ``dbo.sysobjects`` and requires
    ``OBJECTPROPERTY(id, 'IsProcedure') = 1``.

    Returns:
        True when a matching procedure exists, otherwise False.
    """
    cursor = connect.connection.cursor()
    try:
        query = f"""SELECT  COUNT(id)  FROM   dbo.sysobjects   WHERE   id=object_id(N'{procedure_name}') AND OBJECTPROPERTY(id, N'IsProcedure')=1"""
        cursor.execute(query)
        matches = cursor.fetchone()[0]
        return matches != 0
    finally:
        cursor.close()


def get_databases(connect: ConnectionHelper) -> List[str]:
    """
    Return the names of the instance's user databases, sorted by name.

    The system databases master, tempdb, model and msdb are excluded.

    Args:
        connect: helper whose ``connection`` is open.

    Returns:
        List of database names (possibly empty).
    """
    cur = connect.connection.cursor()
    try:
        # Select the name column explicitly instead of relying on it being
        # the first column of ``SELECT *``.
        cur.execute(
            """SELECT name FROM sys.sysdatabases WHERE name not in ('master','tempdb','model','msdb') ORDER BY name""")
        return [row[0] for row in cur.fetchall()]
    finally:
        cur.close()

def mkdir_in_server(connect: ConnectionHelper, backup_dir: str):
    """
    Create the directory ``backup_dir`` on the server via ``xp_create_subdir``.

    The path is embedded in an N'...' literal and the statement is committed
    through ``exec_sql``.
    """
    statement = f"""EXEC sys.XP_CREATE_SUBDIR N'{backup_dir}'"""
    exec_sql(connect=connect, sql=statement)


def get_outbound_sql(target_name: str, counter_name: str, master_key_password, cert_name, endpoint_name, port, cert_file_path):
    """
    Render ``outbound.template.sql.j2`` from the ``helper`` package's templates.

    The certificate expiry passed to the template is fixed at '2099-01-01'.

    Returns:
        str: the rendered outbound-setup SQL text.
    """
    context = {
        'target_name': target_name,
        'counter_name': counter_name,
        'master_key_password': master_key_password,
        'cert_name': cert_name,
        # Far-future expiry date.
        'expiry_date': '2099-01-01',
        'endpoint_name': endpoint_name,
        'port': port,
        'cert_file_path': cert_file_path,
    }
    jinja_env = Environment(loader=PackageLoader('helper'),
                            autoescape=select_autoescape())
    outbound_template = jinja_env.get_template("outbound.template.sql.j2")
    return outbound_template.render(**context)


def get_inbound_sql(target_name: str, counter_name: str, cert_name: str, user: str, login: str, login_password: str, cert_dir_path: str, cert_file_path: str, endpoint_name: str):
    """
    Render ``inbound.template.sql.j2`` from the ``helper`` package's templates.

    Every argument is passed to the template under its own name.

    Returns:
        str: the rendered inbound-setup SQL text.
    """
    jinja_env = Environment(loader=PackageLoader('helper'),
                            autoescape=select_autoescape())
    inbound_template = jinja_env.get_template("inbound.template.sql.j2")
    context = dict(target_name=target_name,
                   counter_name=counter_name,
                   cert_name=cert_name,
                   user=user,
                   login=login,
                   login_password=login_password,
                   cert_dir_path=cert_dir_path,
                   cert_file_path=cert_file_path,
                   endpoint_name=endpoint_name)
    return inbound_template.render(**context)


def get_stop_mirroring_sql(target_name: str, backup_dir: str, databases: List[str]):
    """
    Render ``stop_mirroring.sql.j2``: SQL that stops mirroring and shrinks databases.

    Args:
        target_name: server label passed to the template.
        backup_dir: backup directory passed to the template.
        databases: database names passed to the template.

    Returns:
        str: the rendered SQL text.
    """
    jinja_env = Environment(loader=PackageLoader('helper'),
                            autoescape=select_autoescape())
    stop_template = jinja_env.get_template("stop_mirroring.sql.j2")
    context = dict(target_name=target_name, backup_dir=backup_dir, databases=databases)
    return stop_template.render(**context)


def stop_master_mirroring(target_name: str, backup_dir: str, databases: List[str]):
    """
    Write the master server's stop-mirroring script to ``stop_master_mirroring.sql``.

    The script is rendered by ``get_stop_mirroring_sql`` and written to the
    current working directory, then the user is told to run it.

    Args:
        target_name: server label passed to the template.
        backup_dir: backup directory passed to the template.
        databases: databases whose mirroring should be stopped.
    """
    sql = get_stop_mirroring_sql(target_name, backup_dir, databases)
    # The file name must match the one announced below (it was misspelled
    # "stop_master_mirroing.sql").
    with open('stop_master_mirroring.sql', 'w') as tmp_file:
        tmp_file.write(sql)
    # Announce only after the file has been closed (and flushed).
    print(f'主服务器的停止镜像SQL脚本"stop_master_mirroring.sql"已产生,以管理员账号登录执行该脚本')
